%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% SVM Decoding
%
% GE_Step1d_clusterstats.m
%
% Statistical evaluation using permutation tests
%
% Modified from GE_SVM_clusterstats_singlesubject.m
%    by SA 2019-08-11
%    Input : Individual subjects' classifier accuracy data, 
%            averaged over all condition pairs
%    Output: Cluster Statistics figure of significant time window 
%            for each ROI separately
%
% Modified to include multiple condition sets and save the cluster
% statistics along with the figures
%   by DS 2021-04
%

% Start from a clean session: no open figures, empty workspace.
close all
clear all

% Root of the SVM decoding project; its scripts live one level below.
svm_dir = '/autofs/space/clive_001/users/adriana/GE_SVM';
scripts_dir = [svm_dir '/scripts'];

% Put the project scripts and the fusionlab toolbox (presumably the source
% of the fl_* functions used below) on the path, with all subfolders.
addpath(genpath(scripts_dir));
addpath(genpath('/autofs/space/clive_001/users/adriana/UAG_SVM/scripts/fusionlab_toolbox'));

% Specify the list of subjects (each name is also the prefix of the
% subject's results directory and of its result file names):

SubjectNames = {'GE_05', 'GE_06', 'GE_07', 'GE_08', 'GE_09', 'GE_10', 'GE_11', 'GE_12', 'GE_13', 'GE_14', 'GE_15', 'GE_16'};
n_subjects = length(SubjectNames);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Specify the name of the analysis type. The tag is appended to each
% subject name to form that subject's directory; '' selects the default
% analysis.

analysis_tag = '';
%analysis_tag = '_neighbors_vs_seeds_16Slices_2Rois_SMG_pMTG';
%analysis_tag = '_neighbors_vs_seeds_32slices';

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Specify the Regions of Interest ("ROIs", also called "labels"):

roiset_array = {'newSMG', 'newpMTG', 'TempPole', 'FrontalPole', 'CentGyri'};
%roiset_array = {'SMG_and_pMTG', 'SMG', 'pMTG'};
n_roisets = length(roiset_array);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Specify the condition pairs.
%
% Each of the four *_conditionX_array cells is laid out as
% (number of categories) x (number of condition pairs):
%   - row i lists the condition labels used for categories{i};
%   - column j is the j-th hub-word pair, ordered
%     (1,2),(1,3),...,(1,6),(2,3),...,(5,6).
% The main loop indexes these arrays as array(i_category, i_condpair).
% Rows of a multi-category set must therefore be stacked with ';'
% (vertical concatenation), not ',' (which builds one long row).
condition_tag = 'lvn'; % 'nvs'|'lvn'|'pos'|'sop'|'nvl'|'svn'

switch condition_tag
    case 'nvs'
        % Train all neighbors; test seeds
        categories = {'neighbors'};
        train_conditionA_array = {'11','11','11','11','11','22','22','22','22','33','33','33','44','44','55'};
        train_conditionB_array = {'22','33','44','55','66','33','44','55','66','44','55','66','55','66','66'};
        test_conditionA_array  = {'1','1','1','1','1','2','2','2','2','3','3','3','4','4','5'};
        test_conditionB_array  = {'2','3','4','5','6','3','4','5','6','4','5','6','5','6','6'};

    case 'lvn'
        % Separately train lexical and nonword neighbors; test seeds
        categories = {'nonwords'; 'words'};
        train_conditionA_array = {'100','100','100','100','100','200','200','200','200','300','300','300','400','400','500';...
                                  '10','10','10','10','10','20','20','20','20','30','30','30','40','40','50'};
        train_conditionB_array = {'200','300','400','500','600','300','400','500','600','400','500','600','500','600','600';...
                                  '20','30','40','50','60','30','40','50','60','40','50','60','50','60','60'};
        test_conditionA_array = {'1','1','1','1','1','2','2','2','2','3','3','3','4','4','5';...
                                 '1','1','1','1','1','2','2','2','2','3','3','3','4','4','5'};
        test_conditionB_array = {'2','3','4','5','6','3','4','5','6','4','5','6','5','6','6';...
                                 '2','3','4','5','6','3','4','5','6','4','5','6','5','6','6'};

    case 'pos'
        % Separately train based on differential position; test seeds
        categories = {'initial'; 'vowel'; 'final'};
        train_conditionA_array = {'1001', '1001', '1001', '1001', '1001', '2001', '2001', '2001', '2001', '3001', '3001', '3001', '4001', '4001', '5001';...
                                  '1002', '1002', '1002', '1002', '1002', '2002', '2002', '2002', '2002', '3002', '3002', '3002', '4002', '4002', '5002';...
                                  '1003', '1003', '1003', '1003', '1003', '2003', '2003', '2003', '2003', '3003', '3003', '3003', '4003', '4003', '5003'};
        train_conditionB_array = {'2001', '3001', '4001', '5001', '6001', '3001', '4001', '5001', '6001', '4001', '5001', '6001', '5001', '6001', '6001';...
                                  '2002', '3002', '4002', '5002', '6002', '3002', '4002', '5002', '6002', '4002', '5002', '6002', '5002', '6002', '6002';...
                                  '2003', '3003', '4003', '5003', '6003', '3003', '4003', '5003', '6003', '4003', '5003', '6003', '5003', '6003', '6003'};
        test_conditionA_array = {'1','1','1','1','1','2','2','2','2','3','3','3','4','4','5';...
                                 '1','1','1','1','1','2','2','2','2','3','3','3','4','4','5';...
                                 '1','1','1','1','1','2','2','2','2','3','3','3','4','4','5'};
        test_conditionB_array = {'2','3','4','5','6','3','4','5','6','4','5','6','5','6','6';...
                                 '2','3','4','5','6','3','4','5','6','4','5','6','5','6','6';...
                                 '2','3','4','5','6','3','4','5','6','4','5','6','5','6','6'};

    case 'sop'
        % Train on hub words; test separately based on change position
        categories = {'initial_rev'; 'vowel_rev'; 'final_rev'};
        test_conditionA_array = {'1001', '1001', '1001', '1001', '1001', '2001', '2001', '2001', '2001', '3001', '3001', '3001', '4001', '4001', '5001';...
                                 '1002', '1002', '1002', '1002', '1002', '2002', '2002', '2002', '2002', '3002', '3002', '3002', '4002', '4002', '5002';...
                                 '1003', '1003', '1003', '1003', '1003', '2003', '2003', '2003', '2003', '3003', '3003', '3003', '4003', '4003', '5003'};
        test_conditionB_array = {'2001', '3001', '4001', '5001', '6001', '3001', '4001', '5001', '6001', '4001', '5001', '6001', '5001', '6001', '6001';...
                                 '2002', '3002', '4002', '5002', '6002', '3002', '4002', '5002', '6002', '4002', '5002', '6002', '5002', '6002', '6002';...
                                 '2003', '3003', '4003', '5003', '6003', '3003', '4003', '5003', '6003', '4003', '5003', '6003', '5003', '6003', '6003'};
        train_conditionA_array = {'1','1','1','1','1','2','2','2','2','3','3','3','4','4','5';...
                                  '1','1','1','1','1','2','2','2','2','3','3','3','4','4','5';...
                                  '1','1','1','1','1','2','2','2','2','3','3','3','4','4','5'};
        train_conditionB_array = {'2','3','4','5','6','3','4','5','6','4','5','6','5','6','6';...
                                  '2','3','4','5','6','3','4','5','6','4','5','6','5','6','6';...
                                  '2','3','4','5','6','3','4','5','6','4','5','6','5','6','6'};
        % NOTE(review): condition_tag is not read again in this file after
        % the switch; confirm whether this reassignment is still needed.
        condition_tag = 'pos';

    case 'nvl'
        % Train on hub words; test separately on words and nonword neighbors (random subset)
        categories = {'nonwords_rev'; 'words_rev'};
        test_conditionA_array = {'100','100','100','100','100','200','200','200','200','300','300','300','400','400','500';...
                                 '10','10','10','10','10','20','20','20','20','30','30','30','40','40','50'};
        test_conditionB_array = {'200','300','400','500','600','300','400','500','600','400','500','600','500','600','600';...
                                 '20','30','40','50','60','30','40','50','60','40','50','60','50','60','60'};
        train_conditionA_array = {'1','1','1','1','1','2','2','2','2','3','3','3','4','4','5';...
                                  '1','1','1','1','1','2','2','2','2','3','3','3','4','4','5'};
        train_conditionB_array = {'2','3','4','5','6','3','4','5','6','4','5','6','5','6','6';...
                                  '2','3','4','5','6','3','4','5','6','4','5','6','5','6','6'};

    case 'svn'
        % Train on hub words; test on all neighbors (random subset)
        categories = {'neighbors_rev'};
        test_conditionA_array = {'11', '11', '11', '11', '11', '22', '22', '22', '22', '33', '33', '33', '44', '44', '55'};
        test_conditionB_array = {'22', '33', '44', '55', '66', '33', '44', '55', '66', '44', '55', '66', '55', '66', '66'};
        train_conditionA_array = {'1', '1', '1', '1', '1', '2', '2', '2', '2', '3', '3', '3', '4', '4', '5'};
        train_conditionB_array = {'2', '3', '4', '5', '6', '3', '4', '5', '6', '4', '5', '6', '5', '6', '6'};

    otherwise
        error('Unknown condition_tag "%s" (expected nvs|lvn|pos|sop|nvl|svn).', condition_tag);
end

% One condition pair per column. All four label arrays must have the same
% size, with one row per category; otherwise the main loop would index
% past the end or mix up labels.
n_condpairs = size(train_conditionA_array, 2);
if ~isequal(size(train_conditionA_array), size(train_conditionB_array), ...
            size(test_conditionA_array), size(test_conditionB_array)) ...
        || size(train_conditionA_array, 1) ~= numel(categories)
    error('Condition arrays for "%s" must all be %d x %d cell arrays (categories x condition pairs).', ...
          condition_tag, numel(categories), n_condpairs);
end

% Subplot grid of the per-condition-pair figure: one row and one column
% per hub word except the last, so pair (iA, iB) lands at row iA,
% column iB-1.
n_hubwords = 6;
n_rows = n_hubwords - 1;
n_cols = n_rows;

% Cluster-based permutation test settings. alpha_array(k) and
% cluster_alpha_array(k) are used together as the k-th setting.
% Values tried before: 0.001, 0.01, 0.05, 0.1 and 0.5 (e.g. [0.05, 0.1]).
num_permutation = 1000;
alpha_array         = 0.05;
cluster_alpha_array = 0.05;
n_alpha = min(numel(alpha_array), numel(cluster_alpha_array));

% Expected number of samples per accuracy time course; not used anywhere
% else in this file at the moment.
timepoints = 1101;

% 1 = write .fig/.png/.mat outputs, 0 = only display the figures.
savefigs = 1;

% Per-subject inputs to the cluster test (accuracy minus 50 % chance).
zeroed_data_cell     = cell(1, n_subjects);
zeroed_ave_data_cell = cell(1, n_subjects);

%accuracy_all = NaN(n_subjects, n_condpairs,timepoints); %store decoding accuracy for all condition pairs for all subjects

% Timestamp used as the name of this run's output sub-directory.
time_tag = datetime('now', 'Format', 'yyyyMMdd-HHmm');

% Main analysis loop. For every category x alpha setting x ROI:
%   1. For each condition pair, load every subject's decoding-accuracy time
%      course. Run a right-tailed cluster permutation test (fl_permtestcluster)
%      of accuracy - 50 against zero. Plot the group mean with the significant
%      time points, one subplot per pair.
%   2. Average each subject's accuracy over all condition pairs, run the same
%      test on those averages, and plot it in the lower-left subplot and as a
%      separate stand-alone figure.
% When savefigs is set, figures and stat_ave are written to
%   <svm_dir>/GE_grand_average_results/<category>/<time_tag>/
my_color    = [.9 .2 .2];   % marker colour for significant time points
my_shading  = [.5 .5 .5];   % fill colour for significant clusters

for i_category = 1:length(categories)
    category = categories{i_category};

    results_dir = sprintf('%s/GE_grand_average_results/%s/%s', svm_dir, category, time_tag);
    if savefigs
        [mkdir_ok, mkdir_msg] = mkdir(results_dir);   % no-op if it already exists
        if ~mkdir_ok
            error('Could not create results directory %s: %s', results_dir, mkdir_msg);
        end
    end

    for i_alpha = 1:n_alpha   % try different cluster statistics parameters
        this_alpha = alpha_array(i_alpha);
        this_cluster_alpha = cluster_alpha_array(i_alpha);

        for i_roi = 1:n_roisets
            roiset = char(roiset_array(i_roi));

            % accuracy_all(subject, condpair, time), rebuilt from scratch per ROI
            accuracy_all = [];

            figure('Position', [0, 0, 1650, 1200]);

            for i_condpair = 1:n_condpairs

                train_conditionA = char(train_conditionA_array(i_category, i_condpair));
                train_conditionB = char(train_conditionB_array(i_category, i_condpair));
                test_conditionA  = char(test_conditionA_array(i_category, i_condpair));
                test_conditionB  = char(test_conditionB_array(i_category, i_condpair));

                for i_subject = 1:n_subjects
                    subject = char(SubjectNames(i_subject));
                    subject_dir = char(strcat(subject, analysis_tag));
                    input_dir = sprintf('%s/%s/results/', svm_dir, subject_dir);

                    % Get "accuracy" and "Time" from file. Load into a struct so
                    % the file's variables cannot overwrite script variables, and
                    % so an 'x' loaded from one file cannot leak into the next.
                    inputfile = sprintf('%s/%s_%s_train%sand%s_test%svs%s_Accuracy.mat', input_dir, subject, roiset, train_conditionA, train_conditionB, test_conditionA, test_conditionB);
                    file_vars = load(inputfile);
                    if isfield(file_vars, 'x')
                        % Some result files store accuracy as x and the time axis as y
                        accuracy = file_vars.x;
                        Time = file_vars.y;
                    else
                        accuracy = file_vars.accuracy;
                        Time = file_vars.Time;
                    end

                    zeroed_data_cell{1, i_subject} = accuracy - 50;
                    accuracy_all(i_subject, i_condpair, :) = accuracy; %#ok<SAGROW>
                end % i_subject

                % Cluster size analysis (accuracy vs. 50 % chance):
                stat = fl_permtestcluster(zeroed_data_cell, 'tail', 'right', 'statistic', 'tstat', 'verbose', 0, 'numpermutation', num_permutation, 'alpha', this_alpha, 'clusteralpha', this_cluster_alpha);

                % Group-mean accuracy for this condition pair
                accuracy_ave = squeeze(mean(accuracy_all(:, i_condpair, :), 1));

                % Subplot position from the hub-word pair (iA, iB), iA < iB.
                % The leading digit of a label is the hub-word index in every
                % condition set ('2', '20', '200', '22', '2001'). Using the full
                % numeric value would give an out-of-range subplot index for the
                % neighbor test labels of 'sop', 'nvl' and 'svn'.
                iA = str2double(test_conditionA(1));
                iB = str2double(test_conditionB(1));
                subplot(n_rows, n_cols, (iA - 1) * n_cols + iB - 1);

                plot(Time, accuracy_ave);
                hold on

                % Mark time points inside significant clusters
                tndx = find(stat.criticalmap);
                plot(Time(tndx), 75 * ones(size(tndx)), '.', 'Color', my_color);
                ylabel('Accuracy (%)', 'fontsize', 10)
                xlabel('Time (msec)', 'fontsize', 10)
                yline(50);

                titletext = sprintf('Cluster Stats %s: train %s %s, test %s %s', roiset, train_conditionA, train_conditionB, test_conditionA, test_conditionB);
                title(titletext, 'fontsize', 10, 'Interpreter', 'none');
                axis([-200 1100 0 100]);
                legend('off');

            end % i_condpair

            % Cluster statistics for data averaged over all condition pairs
            % (separately for each subject):
            for i_subject = 1:n_subjects
                zeroed_ave_data_cell{1, i_subject} = squeeze(mean(accuracy_all(i_subject, :, :), 2)) - 50;
            end % i_subject

            subject_means    = mean(accuracy_all, 2);              % n_subjects x 1 x T
            accuracy_all_ave = squeeze(mean(subject_means, 1));    % T x 1
            % std(X, 0, 1): sample (N-1) std across subjects (dimension 1).
            % The earlier std(X, 1) passed 1 as the *weight* argument, not as
            % the dimension, and so used the population (N) normalisation.
            accuracy_all_std = squeeze(std(subject_means, 0, 1)); % T x 1
            accuracy_all_sem = accuracy_all_std / sqrt(n_subjects);

            stat_ave = fl_permtestcluster(zeroed_ave_data_cell, 'tail', 'right', 'statistic', 'tstat', 'verbose', 0, 'numpermutation', num_permutation, 'alpha', this_alpha, 'clusteralpha', this_cluster_alpha);

            % Plot the mean accuracy cluster stats in the lower left corner of the figure:
            subplot(n_rows, n_cols, (n_rows - 1) * n_cols + 1);
            plot(Time, accuracy_all_ave);
            shadedErrorBar(Time, accuracy_all_ave, accuracy_all_sem, 'patchSaturation', 0.16);
            hold on
            tndx = find(stat_ave.criticalmap);
            plot(Time(tndx), 65 * ones(size(tndx)), '.', 'Color', my_color);
            ylabel('Accuracy (%)', 'fontsize', 10)
            xlabel('Time (msec)', 'fontsize', 10)
            yline(50);

            plot_title = analysis_tag;
            maxlength = 28; % cutoff for too long title information
            if length(plot_title) > maxlength
                plot_title = plot_title(1:maxlength);
            end
            titletext = sprintf('GE%s: %s (cluster alpha %s, alpha %s)', plot_title, roiset, string(this_cluster_alpha), string(this_alpha));
            title(titletext, 'fontsize', 10, 'Interpreter', 'none');
            axis([-200 1100 25 75]);
            legend('off');

            if savefigs
                fig_save_stem = sprintf('ClusterStats_%s_allcondpairs_clusteralpha_%s_alpha_%s', roiset, string(this_cluster_alpha), string(this_alpha));
                savefig(sprintf('%s/%s.fig', results_dir, fig_save_stem));
                print(gcf, sprintf('%s/%s.png', results_dir, fig_save_stem), '-dpng', '-r300');
            end % if savefigs

            % Plot the mean accuracy cluster stats, nicely formatted with gray
            % highlighting. The figure number includes the category so figures
            % from different categories are not reused. The band shows the
            % std across subjects (not the SEM).
            figure(i_category * 1000 + i_roi * 100 + i_alpha);
            plot(Time, accuracy_all_ave, 'Color', 'b');
            shadedErrorBar(Time, accuracy_all_ave, accuracy_all_std, 'lineprops', '-b', 'patchSaturation', 0.16);

            hold on
            ylabel('Decoding Accuracy (%)', 'fontsize', 10)
            xlabel('Time (msec)', 'fontsize', 10)
            x_min = -200;
            x_max = 1100;
            y_min = 30;
            y_max = 70;
            axis([x_min x_max y_min y_max]);
            xticks(-100:100:1000);
            yticks(y_min:5:y_max);
            yline(50);
            xline(0);
            legend('off');
            title(roiset, 'fontsize', 10, 'Interpreter', 'none');

            % Shade the time span of each cluster. Transparency is set on the
            % area itself so the error band's patch keeps its own look.
            % NOTE(review): confirm stat_ave.clusters holds only significant
            % clusters (cf. stat_ave.criticalmap used above).
            for i_clust = 1:length(stat_ave.clusters)
                iic = stat_ave.clusters{i_clust};
                area([Time(iic(1)) Time(iic(end))], [100 100], 'FaceColor', my_shading, 'FaceAlpha', 0.2, 'LineStyle', 'none');
            end % i_clust
            hold off;

            if savefigs
                fig_save_stem = sprintf('ClusterStats_%s_clusteralpha_%s_alpha_%s', roiset, string(this_cluster_alpha), string(this_alpha));
                savefig(sprintf('%s/%s.fig', results_dir, fig_save_stem));
                print(gcf, sprintf('%s/%s.png', results_dir, fig_save_stem), '-dpng', '-r300');
                % NOTE(review): this name has no alpha tag, so with several
                % alpha settings each one overwrites the previous stat_ave.
                stat_filename = sprintf('%s/cluster_stats_%s', results_dir, roiset_array{i_roi});
                save(stat_filename, 'stat_ave');
            end % if savefigs

        end % i_roi

    end % i_alpha

end % i_category



% figures=[1:n_figures];
% figSize = [58, 48];            % [width, height]
% figUnits = 'Centimeters';
% for f = 1:numel(figures)
%       fig = figures(f);
%       % Resize the figure
%
%       % Output the figure
%       fig_name=cell2str(fignames{1,f});
%       filename = sprintf('%s.jpg', fig_name);
%       print( fig, '-dpdf', filename );
% end

% %cluster size analysis
% stat = fl_permtestcluster(data,'tail','right','statistic','tstat','verbose',0,'numpermutation',1000,'alpha',0.05,'clusteralpha',0.05);
% figure;
% plot(time,stat.statmap);
% hold on;
% tndx = find(stat.criticalmap);
% plot(time(tndx),25*ones(size(tndx)),'b*')

% %fdr analysis
% stat = fl_permtest(data,'tail','right','statistic','tstat','verbose',0,'numpermutation',1000,'alpha',0.05);
% figure;
% plot(time,stat.statmap);
% hold on;
% tndx = find(stat.FDR.criticalmap);
% plot(time(tndx),25*ones(size(tndx)),'b*')
%
%

